{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CosFace: Large Margin Cosine Loss for Deep Face Recognition\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Framework\n",
    "![image](images/overall.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## original softmax\n",
    "![image](images/orignal_softmax.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalize the weight and feature\n",
    "![title](images/norm_f.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalized version of Softmax Loss (NSL)\n",
    "![title](images/norm_softmax_loss_f.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Large Margin Cosine Loss (LMCL)\n",
    "![title](images/large_margin_loss_new2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Py_func implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def py_func(func, inp, Tout, stateful = True, name=None, grad_func=None):\n",
    "    # Need to generate a unique name to avoid duplicates\n",
    "    rand_name = 'PyFuncGrad' + str(np.random.randint(0,1E+8))\n",
    "    tf.RegisterGradient(rand_name)(grad_func)\n",
    "    g = tf.get_default_graph()\n",
    "    with g.gradient_override_map({'PyFunc':rand_name}):\n",
    "        return tf.py_func(func,inp,Tout,stateful=stateful, name=name)\n",
    "\n",
    "\n",
    "def coco_forward(xw, y, m, name=None):\n",
    "    num = len(y)\n",
    "    orig_ind = range(num)\n",
    "    xw[orig_ind,y] -= m\n",
    "    return xw\n",
    "\n",
    "def coco_help(grad,y):\n",
    "    grad_copy = grad.copy()\n",
    "    return grad_copy\n",
    "\n",
    "def coco_backward(op, grad):\n",
    "    \n",
    "    y = op.inputs[1]\n",
    "    m = op.inputs[2]\n",
    "    grad_copy = tf.py_func(coco_help,[grad,y],tf.float32)\n",
    "    return grad_copy,y,m\n",
    "\n",
    "def coco_func(xw,y,m, name=None):\n",
    "    with tf.op_scope([xw,y,m],name,\"Coco_func\") as name:\n",
    "        coco_out = py_func(coco_forward,[xw,y,m],tf.float32,name=name,grad_func=coco_backward)\n",
    "        return coco_out\n",
    "    \n",
    "# function interface\n",
    "def cosine_loss(x,y,num_cls,reuse=False,alpha=0.25,scale=64,name='cosine_loss'):\n",
    "    '''\n",
    "    x: B x D - features\n",
    "    y: B x 1 - labels\n",
    "    num_cls: 1 - total class number\n",
    "    alpah: 1 - margin\n",
    "    scale: 1 - scaling paramter\n",
    "    '''\n",
    "    # define the classifier weights\n",
    "    xs = x.get_shape()\n",
    "    with tf.variable_scope('centers_var',reuse=reuse) as center_scope:\n",
    "        w = tf.get_variable(\"centers\", [xs[1], num_cls], dtype=tf.float32, \n",
    "            initializer=tf.contrib.layers.xavier_initializer(),trainable=True)\n",
    "   \n",
    "    #normalize the feature and weight\n",
    "    #(N,D)\n",
    "    x_feat_norm = tf.nn.l2_normalize(x,1,1e-10)\n",
    "    #(D,C)\n",
    "    w_feat_norm = tf.nn.l2_normalize(w,0,1e-10)\n",
    "    \n",
    "    # get the scores after normalization \n",
    "    #(N,C)\n",
    "    xw_norm = tf.matmul(x_feat_norm, w_feat_norm)  \n",
    "    #value = tf.identity(xw)\n",
    "    #substract the marigin and scale it\n",
    "    value = coco_func(xw_norm,y,alpha) * scale\n",
    "    \n",
    "    # compute the loss as softmax loss\n",
    "    cos_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=value))\n",
    "\n",
    "    return cos_loss "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tf api implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cos_loss(x, y,  num_cls, reuse=False, alpha=0.25, scale=64,name = 'cos_loss'):\n",
    "    '''\n",
    "    x: B x D - features\n",
    "    y: B x 1 - labels\n",
    "    num_cls: 1 - total class number\n",
    "    alpah: 1 - margin\n",
    "    scale: 1 - scaling paramter\n",
    "    '''\n",
    "    # define the classifier weights\n",
    "    xs = x.get_shape()\n",
    "    with tf.variable_scope('centers_var',reuse=reuse) as center_scope:\n",
    "        w = tf.get_variable(\"centers\", [xs[1], num_cls], dtype=tf.float32, \n",
    "            initializer=tf.contrib.layers.xavier_initializer(),trainable=True)\n",
    "   \n",
    "    #normalize the feature and weight\n",
    "    #(N,D)\n",
    "    x_feat_norm = tf.nn.l2_normalize(x,1,1e-10)\n",
    "    #(D,C)\n",
    "    w_feat_norm = tf.nn.l2_normalize(w,0,1e-10)\n",
    "    \n",
    "    # get the scores after normalization \n",
    "    #(N,C)\n",
    "    xw_norm = tf.matmul(x_feat_norm, w_feat_norm)  \n",
    "    #implemented by py_func\n",
    "    #value = tf.identity(xw)\n",
    "    #substract the marigin and scale it\n",
    "    #value = coco_func(xw_norm,y,alpha) * scale\n",
    "\n",
    "    #implemented by tf api\n",
    "    margin_xw_norm = xw_norm - alpha\n",
    "    label_onehot = tf.one_hot(y,num_cls)\n",
    "    value = scale*tf.where(tf.equal(label_onehot,1), margin_xw_norm, xw_norm)\n",
    "\n",
    "    \n",
    "    # compute the loss as softmax loss\n",
    "    cos_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=value))\n",
    "\n",
    "    return cos_loss \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
